import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from transformers import GPTJForCausalLM, GPT2Tokenizer
from transformers import GPTNeoForCausalLM, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer
from transformers import set_seed
# from transformers import GPT2Tokenizer, OPTForCausalLM
import json
import argparse
import random
import pickle
from tqdm import tqdm
from data_utils import get_masked_edits, process_mquake_remastered_cf_6334
import pandas as pd
from huggingface_hub import login
from vllm import LLM, SamplingParams


def icl_lm_eval(model, sampling_params, icl_example, target, ice_example):
    """Run one generation request and return the stripped completion text.

    Args:
        model: object whose ``generate(prompts, sampling_params, use_tqdm=...)``
            returns a sequence of request outputs, each with an ``outputs``
            sequence whose items carry a ``text`` attribute (vLLM ``LLM`` style).
        sampling_params: forwarded unchanged to ``model.generate``.
        icl_example: the prompt string actually sent to the model.
        target: expected answer; only printed for inspection.
        ice_example: prompt text echoed in the log (callers pass the same
            string as ``icl_example``).

    Returns:
        The text of the first completion of the first request, whitespace-stripped.
    """
    with torch.no_grad():
        request_outputs = model.generate([icl_example], sampling_params, use_tqdm=False)
        first_completion = request_outputs[0].outputs[0]
        answer = first_completion.text.strip()
        # Human-readable trace of prompt, expected answer and model output.
        print('\nInput: ', ice_example, '\nTarget: ', target, '\n\nOutput: ', answer)
        print("==" * 50)
    return answer

    
def _collect_global_facts(dataset, rand_list):
    """Return deduplicated edit statements for every case whose id is in ``rand_list``.

    Each statement is ``prompt.format(subject) + ' ' + target_new['str']`` built
    from the case's ``requested_rewrite`` entries. Duplicates are removed while
    keeping first-seen order (dict keys). A plain ``set`` would make the fact
    order, and hence the prompt, depend on per-process string hash randomization.
    Returns ``["No relevant fact."]`` when no statement is collected.
    """
    selected_ids = set(rand_list)
    facts = {}
    for case in dataset:
        if case['case_id'] not in selected_ids:
            continue
        for rewrite in case["requested_rewrite"]:
            fact = f'{rewrite["prompt"].format(rewrite["subject"])} {rewrite["target_new"]["str"]}'
            facts[fact] = None
    return list(facts) or ["No relevant fact."]


def ice_eval_loop(mquake_dataset, edited_caseid, model, sampling_params, result_file_path, masking, dataset_name):
    """Query the model on every question with edit facts prepended and dump raw answers.

    Args:
        mquake_dataset: object providing ``get_dataset()`` (sequence of case
            dicts with ``case_id``, ``requested_rewrite``, ``questions``,
            ``answer`` and ``new_answer``), ``get_randlist()`` (ids of the
            edited cases) and ``get_edits_without_contamination(rand_list, case)``
            (returns a 4-tuple whose first item is the facts to use for that case).
        edited_caseid: collection of case ids treated as edited; those cases are
            scored against ``new_answer``, the others against ``answer``.
        model, sampling_params: forwarded to ``icl_lm_eval`` for generation.
        result_file_path: path of the JSON file written at the end.
        masking: when true and ``dataset_name != 'CF-6334'``, facts are chosen
            per case via ``get_edits_without_contamination``. Otherwise, one
            shared pool of all edits from ``rand_list`` is used for every case.
        dataset_name: dataset identifier; only compared against ``'CF-6334'``.

    Writes a JSON object mapping each case id (stringified by ``json``) to
    ``{'edited': bool, 'answers': [one model output per question]}``.
    """
    dataset = mquake_dataset.get_dataset()
    rand_list = mquake_dataset.get_randlist()
    edited_ids = set(edited_caseid)  # O(1) membership instead of a list scan

    per_case_masking = dataset_name != 'CF-6334' and masking
    # The shared fact pool is only needed when facts are not chosen per case.
    global_facts = None if per_case_masking else _collect_global_facts(dataset, rand_list)

    raw_answer_dict = {}
    for d in tqdm(dataset):
        case_id = d['case_id']
        is_edited = case_id in edited_ids
        answers = []
        raw_answer_dict[case_id] = {'edited': is_edited, 'answers': answers}

        if per_case_masking:
            new_facts, _, _, _ = mquake_dataset.get_edits_without_contamination(rand_list, d)
            if not new_facts:
                new_facts = ["No relevant fact."]
        else:
            new_facts = global_facts
        facts_prefix = ' '.join(new_facts)  # identical for every question of this case

        target = d['new_answer'] if is_edited else d['answer']
        for prompt_question in d['questions']:
            ice_example = facts_prefix + ' ' + prompt_question + ' Answer concisely.'
            answers.append(icl_lm_eval(model, sampling_params, ice_example, target, ice_example))

    with open(result_file_path, 'w') as fp:
        json.dump(raw_answer_dict, fp, indent=4)